#%% 
import shap  # https://github.com/slundberg/shap
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import argparse
import time
# Command-line hyperparameters. In this script only the final "ablation study"
# cell reads `args`; the wide and deep runs use hard-coded hyperparameters.
parser = argparse.ArgumentParser(
    description='Train SimSHAP explainers on the census (adult) dataset.')
parser.add_argument('--lr', type=float, default=7e-4,
                    help='learning rate for the ablation run (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=1024,
                    help='batch size for the ablation run (default: %(default)s)')
parser.add_argument('--num_samples', type=int, default=32,
                    help='num_samples passed to SimSHAPSampling.train in the '
                         'ablation run (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=1000,
                    help='max_epochs for the ablation run (default: %(default)s)')
parser.add_argument('--pair_sampling', action='store_true',
                    help='enable paired_sampling in the ablation run')
args = parser.parse_args()
# Load the adult/census data and split it: 20% is held out as the test set,
# then 20% of the remainder becomes the validation set.
X_train, X_test, Y_train, Y_test = train_test_split(
    *shap.datasets.adult(), random_state=7, test_size=0.2)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, random_state=0, test_size=0.2)

# Read column metadata while X_train is still a DataFrame; by default
# StandardScaler.transform returns plain arrays without column names.
num_features = len(X_train.columns)
feature_names = list(X_train.columns)

# Standardise every split using statistics fitted on the training split only.
ss = StandardScaler().fit(X_train)
X_train, X_val, X_test = [ss.transform(part) for part in (X_train, X_val, X_test)]


#%% Load Model
import os
import sys

# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialised, so set it before this cell imports torch. setdefault keeps a
# GPU choice already exported by the shell or a job scheduler; '1' is only
# the fallback.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')

import pickle
import torch
import torch.nn as nn
from fastshap import Surrogate, FastSHAP
from simshap.simshap_sampling import SimSHAPSampling

# '..' is resolved against the current working directory (not this file's
# location) when `models` is imported.
sys.path.append('..')
from models import SimSHAPTabular

device = torch.device('cuda')

# NOTE(review): `model` is not referenced anywhere else in this script.
# pickle.load can execute arbitrary code, so only load trusted files here.
with open('census model.pkl', 'rb') as f:
    model = pickle.load(f)

# The surrogate checkpoint is a whole pickled module. map_location remaps its
# tensors onto `device` even if it was saved on a different GPU index.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which refuses
# to unpickle a full nn.Module. If loading fails there, pass
# weights_only=False (only for trusted files).
surr = torch.load('census surrogate.pt', map_location=device).to(device)
surrogate = Surrogate(surr, num_features)


#%% Wider Net
# Explainer variant: SimSHAPTabular with hidden_dim=1024.
explainer = SimSHAPTabular(in_dim=num_features, hidden_dim=1024, out_dim=2).to(device)

# SimSHAPSampling trains the explainer, using the loaded surrogate as its
# imputer. The validation set is capped at its first 100 rows.
simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)
simshap.train(
    X_train,
    X_val[:100],
    lr=1e-2,
    lr_factor=0.5,
    lookback=20,
    max_epochs=1000,
    batch_size=2048,
    num_samples=32,
    accum_iter=1,
    validation_samples=1024,
    verbose=True,
    bar=False,
)

# Module.cpu() moves the parameters in place and returns the module, so the
# checkpoint holds CPU tensors. The explainer is then moved back to the GPU.
torch.save(explainer.cpu(), 'census simshap wide.pt')
explainer.to(device)

#%% Deeper Net
class Net(nn.Module):
    """Deep MLP explainer: five linear layers with ReLU between them.

    The final layer emits ``out_dim * in_dim`` values per sample, which are
    reshaped so each input row yields an ``(out_dim, in_dim)`` matrix.

    Args:
        in_dim: number of input features; also the size of the last output axis.
        out_dim: number of output rows produced per sample.
        hidden_dim: width of the four hidden layers. Defaults to 128, the
            width this class previously hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=128):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # Submodule attribute names are kept unchanged: they define the
        # state_dict keys and the pickled structure of saved explainers.
        self.linear = nn.Linear(in_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(hidden_dim, hidden_dim)
        self.relu3 = nn.ReLU()
        self.linear4 = nn.Linear(hidden_dim, hidden_dim)
        self.relu4 = nn.ReLU()
        self.linear5 = nn.Linear(hidden_dim, out_dim * in_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_dim) to (batch, out_dim, in_dim).

        The reshape uses ``size(0)`` as the batch axis, so ``x`` must be 2-D.
        """
        out = self.linear(x)
        out = self.relu1(out)
        out = self.linear2(out)
        out = self.relu2(out)
        out = self.linear3(out)
        out = self.relu3(out)
        out = self.linear4(out)
        out = self.relu4(out)
        out = self.linear5(out)
        # Split the flat head output into one (out_dim, in_dim) block per sample.
        out = out.view(out.size(0), self.out_dim, self.in_dim)
        return out
    
# Explainer variant: the custom deep MLP `Net` (four hidden layers).
explainer = Net(in_dim=num_features, out_dim=2).to(device)

# Train with SimSHAPSampling against the surrogate imputer, using lr=1.5e-2
# and a 100-row validation slice.
simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)
simshap.train(
    X_train,
    X_val[:100],
    lr=1.5e-2,
    lr_factor=0.5,
    lookback=20,
    max_epochs=1000,
    batch_size=2048,
    num_samples=32,
    accum_iter=1,
    validation_samples=1024,
    verbose=True,
    bar=False,
)

# Move to CPU (Module.cpu() returns the same module) so the checkpoint holds
# CPU tensors, then return the explainer to the GPU.
torch.save(explainer.cpu(), 'census simshap deep.pt')
explainer.to(device)
#%% Ablation study
# Small SimSHAPTabular explainer (hidden_dim=64) whose training
# hyperparameters come from the command-line `args`.
explainer = SimSHAPTabular(in_dim=num_features, hidden_dim=64, out_dim=2).to(device)
simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)
simshap.train(
    X_train,
    X_val[:100],
    # settings taken from the command line
    lr=args.lr,
    max_epochs=args.epochs,
    batch_size=args.batch_size,
    num_samples=args.num_samples,
    paired_sampling=args.pair_sampling,
    # fixed settings
    lr_factor=0.5,
    lookback=20,
    accum_iter=1,
    validation_samples=1024,
    verbose=True,
    bar=False,
)

# Save a CPU-resident copy of the explainer, then move it back to the GPU.
torch.save(explainer.cpu(), 'census simshap ablation.pt')
explainer.to(device)

